1. Representing Time in Connectionist Networks

This notebook explores how time has been represented in artificial neural networks. There are complex, mathematical models. However, we examine three straightforward methods.

For this exploration, we will use the encoding method from this week's lab.

First, we define some text:

In [1]:
text = ("This is a test. Ok. What comes next? Depends? Yes. " +
        "This is also a way of testing prediction. Ok. " +
        "This is fine. Need lots of data. Ok?")

And an encoding methodology:

In [2]:
letters = list(set([letter for letter in text]))

def encode(letter):
    index = letters.index(letter)
    binary = [0] * len(letters)
    binary[index] = 1
    return binary

patterns = {letter: encode(letter) for letter in letters}

1.1 NETTalk (1987)

Consider the problem of reading. That is, given input text, produce a representation of pronouncing that text.

Consider the word "cache". The first "c" is pronounced /k/, but the second "c" is pronounced /sh/. Now consider the made-up word: Tacoche. How would you pronounce it? How do you even go about making a guess?

In 1987, the program NETTalk was created to attempt to solve this problem.

In [3]:
203 / 7
Out[3]:
29.0

Additional input units are called context or state units.

How is time represented in NETTalk?

Out[4]:
In [5]:
from conx import Network

class NETTalk(Network):
    def initialize_inputs(self):
        pass

    def inputs_size(self):
        # Return the number of inputs:
        return len(text)

    def get_letter(self, i):
        if 0 <= i < len(text):
            return text[i]
        else:
            return text[0]
    
    def get_inputs(self, i):
        inputs = []
        for index in range(i - 3, i + 3 + 1):
            letter = text[i]
            inputs += patterns[letter]
        return [inputs, inputs]
In [18]:
pattern_length = len(patterns[" "])
input_length = pattern_length * 7

nettalk = NETTalk(input_length, 80, input_length)
In [19]:
"".join([str(v) for v in nettalk.get_inputs(0)[0]])
Out[19]:
'00000000100000000000000000000000000001000000000000000000000000000010000000000000000000000000000100000000000000000000000000001000000000000000000000000000010000000000000000000000000000100000000000000000000'
In [8]:
nettalk.train(report_rate=10, max_training_epochs=500)
--------------------------------------------------
Training for max trails: 500 ...
Epoch: 0 TSS error: 13070.2553377 %correct: 0.0
Epoch: 10 TSS error: 716.730931536 %correct: 0.0
Epoch: 20 TSS error: 468.573515397 %correct: 0.0
Epoch: 30 TSS error: 431.95824528 %correct: 0.0
Epoch: 40 TSS error: 416.634038769 %correct: 0.0
Epoch: 50 TSS error: 265.945084501 %correct: 0.37593984962406013
Epoch: 60 TSS error: 261.9540547 %correct: 0.42105263157894735
Epoch: 70 TSS error: 251.242554953 %correct: 0.47368421052631576
Epoch: 80 TSS error: 249.358370138 %correct: 0.42857142857142855
Epoch: 90 TSS error: 240.833746482 %correct: 0.47368421052631576
Epoch: 100 TSS error: 238.423773943 %correct: 0.47368421052631576
Epoch: 110 TSS error: 236.585542687 %correct: 0.47368421052631576
Epoch: 120 TSS error: 235.335515939 %correct: 0.47368421052631576
Epoch: 130 TSS error: 234.6405665 %correct: 0.46616541353383456
Epoch: 140 TSS error: 232.128789819 %correct: 0.47368421052631576
Epoch: 150 TSS error: 223.954463761 %correct: 0.47368421052631576
Epoch: 160 TSS error: 222.844211745 %correct: 0.47368421052631576
Epoch: 170 TSS error: 218.780971297 %correct: 0.47368421052631576
Epoch: 180 TSS error: 212.348584825 %correct: 0.38345864661654133
Epoch: 190 TSS error: 208.880241851 %correct: 0.47368421052631576
Epoch: 200 TSS error: 200.669794731 %correct: 0.47368421052631576
Epoch: 210 TSS error: 199.049227988 %correct: 0.45112781954887216
Epoch: 220 TSS error: 185.857938381 %correct: 0.43609022556390975
Epoch: 230 TSS error: 184.198280768 %correct: 0.47368421052631576
Epoch: 240 TSS error: 183.701842181 %correct: 0.47368421052631576
Epoch: 250 TSS error: 182.849165753 %correct: 0.47368421052631576
Epoch: 260 TSS error: 180.944417168 %correct: 0.47368421052631576
Epoch: 270 TSS error: 174.454805672 %correct: 0.47368421052631576
Epoch: 280 TSS error: 170.033090104 %correct: 0.47368421052631576
Epoch: 290 TSS error: 169.799503757 %correct: 0.47368421052631576
Epoch: 300 TSS error: 169.685143824 %correct: 0.47368421052631576
Epoch: 310 TSS error: 169.584419763 %correct: 0.47368421052631576
Epoch: 320 TSS error: 169.335318947 %correct: 0.47368421052631576
Epoch: 330 TSS error: 168.368304961 %correct: 0.47368421052631576
Epoch: 340 TSS error: 156.421263067 %correct: 0.5488721804511278
Epoch: 350 TSS error: 152.673945599 %correct: 0.5488721804511278
Epoch: 360 TSS error: 148.779486623 %correct: 0.5639097744360902
Epoch: 370 TSS error: 148.566616051 %correct: 0.5639097744360902
Epoch: 380 TSS error: 148.3543419 %correct: 0.5639097744360902
Epoch: 390 TSS error: 141.709235563 %correct: 0.556390977443609
Epoch: 400 TSS error: 139.634931458 %correct: 0.5639097744360902
Epoch: 410 TSS error: 135.253393552 %correct: 0.5639097744360902
Epoch: 420 TSS error: 132.899621933 %correct: 0.5939849624060151
Epoch: 430 TSS error: 129.947022062 %correct: 0.5939849624060151
Epoch: 440 TSS error: 129.634026189 %correct: 0.5939849624060151
Epoch: 450 TSS error: 129.34233424 %correct: 0.5939849624060151
Epoch: 460 TSS error: 127.807039826 %correct: 0.5939849624060151
Epoch: 470 TSS error: 127.60060166 %correct: 0.5939849624060151
Epoch: 480 TSS error: 127.514805405 %correct: 0.5939849624060151
Epoch: 490 TSS error: 127.376012662 %correct: 0.5939849624060151
Epoch: 500 TSS error: 126.600145442 %correct: 0.6015037593984962
--------------------------------------------------
Epoch: 500 TSS error: 126.600145442 %correct: 0.6015037593984962
In [9]:
nettalk.get_history()
Out[9]:
[[0, 13070.255337690462, 0.0],
 [10, 716.73093153640559, 0.0],
 [20, 468.57351539725056, 0.0],
 [30, 431.95824528003442, 0.0],
 [40, 416.63403876878135, 0.0],
 [50, 265.9450845005108, 0.37593984962406013],
 [60, 261.95405469996177, 0.42105263157894735],
 [70, 251.24255495313702, 0.47368421052631576],
 [80, 249.35837013797794, 0.42857142857142855],
 [90, 240.83374648196263, 0.47368421052631576],
 [100, 238.42377394323839, 0.47368421052631576],
 [110, 236.5855426873446, 0.47368421052631576],
 [120, 235.33551593933615, 0.47368421052631576],
 [130, 234.64056650023889, 0.46616541353383456],
 [140, 232.12878981907664, 0.47368421052631576],
 [150, 223.95446376119079, 0.47368421052631576],
 [160, 222.84421174504089, 0.47368421052631576],
 [170, 218.78097129678011, 0.47368421052631576],
 [180, 212.34858482531087, 0.38345864661654133],
 [190, 208.88024185073471, 0.47368421052631576],
 [200, 200.66979473088767, 0.47368421052631576],
 [210, 199.0492279875393, 0.45112781954887216],
 [220, 185.85793838145366, 0.43609022556390975],
 [230, 184.19828076777677, 0.47368421052631576],
 [240, 183.70184218147062, 0.47368421052631576],
 [250, 182.8491657534002, 0.47368421052631576],
 [260, 180.94441716786105, 0.47368421052631576],
 [270, 174.45480567201344, 0.47368421052631576],
 [280, 170.03309010410015, 0.47368421052631576],
 [290, 169.79950375737079, 0.47368421052631576],
 [300, 169.68514382370628, 0.47368421052631576],
 [310, 169.58441976319799, 0.47368421052631576],
 [320, 169.3353189471139, 0.47368421052631576],
 [330, 168.36830496114825, 0.47368421052631576],
 [340, 156.42126306674979, 0.5488721804511278],
 [350, 152.67394559851809, 0.5488721804511278],
 [360, 148.77948662250174, 0.5639097744360902],
 [370, 148.56661605062305, 0.5639097744360902],
 [380, 148.35434190048989, 0.5639097744360902],
 [390, 141.70923556269494, 0.556390977443609],
 [400, 139.63493145836725, 0.5639097744360902],
 [410, 135.25339355186716, 0.5639097744360902],
 [420, 132.89962193267056, 0.5939849624060151],
 [430, 129.94702206152613, 0.5939849624060151],
 [440, 129.63402618887667, 0.5939849624060151],
 [450, 129.34233423992214, 0.5939849624060151],
 [460, 127.80703982616402, 0.5939849624060151],
 [470, 127.60060165960947, 0.5939849624060151],
 [480, 127.51480540540372, 0.5939849624060151],
 [490, 127.37601266242373, 0.5939849624060151],
 [500, 126.60014544226198, 0.6015037593984962]]
In [10]:
%matplotlib notebook

import matplotlib.pyplot as plt
plt.plot([item[1] for item in nettalk.get_history()[10:]])
Out[10]:
[<matplotlib.lines.Line2D at 0x7f4a522b3908>]
In [20]:
def winning_output(outputs):
    """
    Given outputs, what letter is this associated with?
    """
    value = max(outputs)
    index = list(outputs).index(value)
    for key,value in patterns.items():
        if value[index] == 1:
            return key
    return "?"
In [21]:
winning_output([0 for i in range(input_length)])
Out[21]:
'N'
In [22]:
import random

def select(outputs):
    index = 0
    partsum = 0.0
    sumFitness = sum(outputs)
    if sumFitness == 0:
        raise Exception("outputs has a sum of zero")
    spin = random.random() * sumFitness
    while index < len(outputs) - 1:
        score = outputs[index]
        if score < 0:
            raise Exception("Negative score: " + str(score))
        partsum += score
        if partsum >= spin:
            break
        index += 1
    return index
In [23]:
pattern_length
Out[23]:
29
In [24]:
test = [random.random() for i in range(pattern_length)]
In [25]:
select(test)
Out[25]:
26
In [26]:
def winning_output(outputs):
    """
    Given outputs, what letter is this associated with?
    """
    index = select(outputs)
    for key,value in patterns.items():
        if value[index] == 1:
            return key
    return "?"
In [27]:
winning_output(test)
Out[27]:
'l'
In [28]:
test2 = [0] * pattern_length
test2[10] = .5
test2[11] = .5
In [33]:
winning_output(test2)
Out[33]:
'd'

1.2 Jordan Net (1986)

In [34]:
from conx import Network

class Jordan(Network):
    def initialize_inputs(self):
        self.last_output = [0] * pattern_length 
        self.last_inputs = [0] * (pattern_length * 2)

    def inputs_size(self):
        # Return the number of inputs:
        return len(text)

    def get_letter(self, i):
        if 0 <= i < len(text):
            return text[i]
        else:
            return text[0]
    
    def get_inputs(self, i):
        #import pdb; pdb.set_trace()
        inputs = patterns[self.get_letter(i)] + list(self.last_output)
        targets = patterns[self.get_letter(i)]
        self.last_output = self.propagate(self.last_inputs)
        self.last_inputs = inputs
        return [inputs, targets]
In [35]:
jordan = Jordan(pattern_length * 2, 10, pattern_length)
In [36]:
jordan.train(report_rate=1, max_training_epochs=10)
--------------------------------------------------
Training for max trails: 10 ...
Epoch: 0 TSS error: 1746.01059228 %correct: 0.0
Epoch: 1 TSS error: 357.091669532 %correct: 0.0
Epoch: 2 TSS error: 125.155245219 %correct: 0.0
Epoch: 3 TSS error: 122.93406114 %correct: 0.0
Epoch: 4 TSS error: 121.433076871 %correct: 0.0
Epoch: 5 TSS error: 119.980575008 %correct: 0.0
Epoch: 6 TSS error: 118.31517078 %correct: 0.0
Epoch: 7 TSS error: 116.234642512 %correct: 0.0
Epoch: 8 TSS error: 113.57362279 %correct: 0.0
Epoch: 9 TSS error: 110.311626446 %correct: 0.0
Epoch: 10 TSS error: 106.69250351 %correct: 0.0
--------------------------------------------------
Epoch: 10 TSS error: 106.69250351 %correct: 0.0

1.3 Simple Recurrent Network (1990)

In [ ]:
from conx import SRN

class Elman(SRN):
    def initialize_inputs(self):
        pass
    
    def inputs_size(self):
        # Return the number of inputs:
        return len(text)

    def get_letter(self, i):
        if 0 <= i < len(text):
            return text[i]
        else:
            return text[0]
    
    def get_inputs(self, i):
        inputs = patterns[self.get_letter(i)]
        targets = patterns[self.get_letter(i + 1)]
        return [inputs, targets]
In [ ]:
elman = Elman(pattern_length, 10, pattern_length)
In [ ]:
elman.train(report_rate=1, max_training_epochs=10)